#!/usr/bin/env python

import clitool
import sys
import itertools
import re

# Splits a string into alternating non-digit / digit runs.  The capture
# group makes re.split() keep the digit runs in its result.  A raw string
# is used so that "\d" is not treated as a (invalid) string escape.
MAGIC_CONVERT_RE = re.compile(r"(\d+)")

# Thresholds used by time_label(): values above BIG_VALUE are shown as
# 'inf' and values below SMALL_VALUE as '-inf'.
BIG_VALUE = 2 ** 60
SMALL_VALUE = -(2 ** 60)

def cmd(f):
    """Decorate a command function with optional authentication and
    optional iteration over all buckets.

    The wrapper accepts up to two positional arguments beyond the ones
    `f` declares.  The first extra argument is used as the username and
    the second (default: empty string) as the password for
    `mc.sasl_auth_plain`, where `mc` is the first positional argument.
    Only the arguments `f` declares are forwarded to it.

    If the keyword argument `allBuckets` is present and true, `f` is run
    once per key returned by `mc.stats('bucket')`, after selecting that
    bucket with `mc.bucket_select`; nothing is returned in that case.
    Otherwise `f` is run once and its result is returned.
    """
    # __code__ is available on Python 2.6+ and 3; func_code is the older
    # Python 2 spelling of the same attribute.
    code = getattr(f, '__code__', None)
    if code is None:
        code = f.func_code
    n = code.co_argcount

    def g(*args, **kwargs):
        mc = args[0]
        if len(args) > n:
            username = args[n]
            if len(args) > n + 1:
                password = args[n + 1]
            else:
                password = ''
            mc.sasl_auth_plain(username, password)

        # A missing flag means "current bucket only" instead of a KeyError.
        if not kwargs.get('allBuckets'):
            return f(*args[:n])

        buckets = mc.stats('bucket')
        for bucket in buckets:
            print('*' * 78)
            print(bucket)
            print('')
            mc.bucket_select(bucket)
            f(*args[:n])

    return g

def stats_perform(mc, cmd=''):
    """Return `mc.stats(cmd)`, or None if the call fails.

    On failure a message naming the requested stats group is printed
    instead of propagating the error.  Only `Exception` subclasses are
    treated as failures, so Ctrl-C and interpreter exit still propagate.
    """
    try:
        return mc.stats(cmd)
    except Exception:
        print("Stats '%s' are not available from the requested engine." % cmd)
        return None

def stats_formatter(stats, prefix=" ", cmp=None):
    """Print a stats mapping as aligned "name: value" lines.

    stats  -- mapping of stat name to value; nothing is printed when it
              is empty or None.
    prefix -- string written at the start of every line.
    cmp    -- optional old-style comparison function taking two
              (name, value) pairs; when omitted the pairs are sorted in
              their natural order.
    """
    if not stats:
        return

    items = stats.items()
    if cmp is None:
        ordered = sorted(items)
    else:
        try:
            from functools import cmp_to_key
        except ImportError:
            # Interpreters older than 2.7: sorted() still accepts cmp=.
            ordered = sorted(items, cmp=cmp)
        else:
            # sorted() lost its cmp= argument in Python 3.
            ordered = sorted(items, key=cmp_to_key(cmp))

    # Column width: longest name plus the ':' and one space.
    longest = max(len(name) for name in stats) + 2
    for stat, val in ordered:
        label = stat + ":"
        print("%s%s%s" % (prefix, label.ljust(longest), val))

@cmd
def stats_vkey(mc, key, vb):
    """Request the "vkey <key> <vb>" stats group and print the result
    under a heading that names the key."""
    request = "vkey %s %s" % (key, str(vb))
    result = mc.stats(request)
    print("verification for key %s" % (key,))
    stats_formatter(result)

@cmd
def stats_all(mc):
    """Print the default stats group (stats_perform with no group name)."""
    everything = stats_perform(mc)
    stats_formatter(everything)

@cmd
def stats_raw(mc, arg):
    """Print whatever `mc.stats(arg)` returns; errors from the call are
    not caught here."""
    raw_stats = mc.stats(arg)
    stats_formatter(raw_stats)

def time_label(s):
    """Format a duration given in microseconds as a short label.

    The value is expressed in the largest unit (m, s, ms, us) that does
    not exceed its magnitude, and is truncated to a whole number:

      below SMALL_VALUE -> '-inf'
      above BIG_VALUE   -> 'inf'
      0                 -> '0'
      4                 -> '4us'
      838384            -> '838ms'
      8283852           -> '8s'
      -1500             -> '-1ms'

    Negative values inside the thresholds are labelled by magnitude with
    a leading '-'.
    """
    if s > BIG_VALUE:
        return 'inf'
    if s < SMALL_VALUE:
        return '-inf'
    if s == 0:
        return '0'

    sign = ''
    magnitude = s
    if s < 0:
        sign = '-'
        magnitude = -s

    # (label, microseconds per unit), largest unit first.
    units = (
        ('m', 60 * 1000 * 1000),
        ('s', 1000 * 1000),
        ('ms', 1000),
        ('us', 1),
    )
    for lbl, factor in units:
        if factor <= magnitude:
            break
    # If no unit fits (magnitude below 1) the loop ends on 'us'.

    # Floor division keeps the whole-number result on Python 2 and 3.
    return "%s%d%s" % (sign, magnitude // factor, lbl)

@cmd
def stats_timings(mc):
    """Print the 'timings' stats as one histogram per stat name.

    Each key returned by `mc.stats('timings')` is expected to look like
    'name_start,end' with an integer count as its value.  For every name
    the buckets are listed in ascending (start, end) order with the
    cumulative percentage, the bucket count, and a '#' bar proportional
    to the bucket's share of the total.
    """
    def seg(k, v):
        # Parse ('some_stat_x,y', 'v') into (('some_stat', x, y), v)
        ka = k.split('_')
        k = '_'.join(ka[0:-1])
        kstart, kend = [int(x) for x in ka[-1].split(',')]
        return ((k, kstart, kend), int(v))

    # Try to figure out the terminal width.  If we can't, 79 is good
    def termWidth():
        try:
            import fcntl, termios, struct
            h, w, hp, wp = struct.unpack('HHHH',
                                         fcntl.ioctl(0, termios.TIOCGWINSZ,
                                                     struct.pack('HHHH', 0, 0, 0, 0)))
        except Exception:
            return 79
        # A reported width of 0 is treated as unknown.
        if w > 0:
            return w
        return 79

    # Acquire, sort, categorize, and label the timings.
    stats = sorted([seg(*kv) for kv in mc.stats('timings').items()])
    dd = {}
    totals = {}
    longest = 0
    for (name, start, end), count in stats:
        lbl = "%s - %s" % (time_label(start), time_label(end))
        longest = max(longest, len(lbl) + 1)
        dd.setdefault(name, []).append((lbl, count))
        totals[name] = totals.get(name, 0) + count

    # The width does not change between rows, so query it only once.
    width = termWidth()

    # Now do the actual output
    for k in sorted(dd):
        total = totals[k]
        print(" %s (%d total)" % (k, total))
        widestnum = max(len(str(v[1])) for v in dd[k])
        ccount = 0
        for lbl, v in dd[k]:
            ccount += v
            # Guard against a histogram whose buckets are all zero.
            if total:
                pcnt = (ccount * 100.0) / total
                lpcnt = float(v) / total
            else:
                pcnt = 0.0
                lpcnt = 0.0
            # This is the important part being printed
            toprint = "    %s: (%6.02f%%) %s" % (lbl.ljust(longest), pcnt,
                                                 str(v).rjust(widestnum))
            # Throw in a bar graph since they're popular and eye catchy
            remaining = width - len(toprint) - 2
            print("%s %s" % (toprint, '#' * int(lpcnt * remaining)))

@cmd
def stats_tap(mc):
    """Print the 'tap' stats group."""
    tap_stats = stats_perform(mc, 'tap')
    stats_formatter(tap_stats)

@cmd
def stats_tapagg(mc):
    """Print the 'tapagg _' stats group."""
    aggregated = stats_perform(mc, 'tapagg _')
    stats_formatter(aggregated)

@cmd
def stats_checkpoint(mc):
    """Print the 'checkpoint' stats group."""
    checkpoint_stats = stats_perform(mc, 'checkpoint')
    stats_formatter(checkpoint_stats)

@cmd
def stats_slabs(mc):
    """Print the 'slabs' stats group."""
    slab_stats = stats_perform(mc, 'slabs')
    stats_formatter(slab_stats)

@cmd
def stats_items(mc):
    """Print the 'items' stats group."""
    item_stats = stats_perform(mc, 'items')
    stats_formatter(item_stats)

def avg(s):
    """Return the sum of the sequence `s` divided by its length.

    The division uses the `/` operator, so the result type follows the
    interpreter's rules for the element type.  An empty sequence raises
    ZeroDivisionError.
    """
    total = sum(s)
    count = len(s)
    return total / count

def _maybeInt(x):
    try:
        return int(x)
    except:
        return x

def _magicConvert(s):
    """Split `s` with MAGIC_CONVERT_RE and return the pieces as a list,
    with every piece that parses as an integer converted to int."""
    converted = []
    for piece in MAGIC_CONVERT_RE.split(s):
        converted.append(_maybeInt(piece))
    return converted

def magic_cmp(a, b):
    """Old-style comparison of two (name, value) pairs by name.

    Each name is converted with _magicConvert, so digit runs compare as
    numbers rather than as text.  Returns -1, 0 or 1.
    """
    am = _magicConvert(a[0])
    bm = _magicConvert(b[0])
    # Portable replacement for the Python-2-only builtin cmp(am, bm).
    return (am > bm) - (am < bm)

@cmd
def stats_hash(mc, with_detail=None):
    """Print a summary of the 'hash' stats group.

    Values whose key contains 'max_dep', 'min_dep' or 'counted' are
    collected and reported as avg/min/max/total summary entries, and
    every key containing ':histo' is summed into a matching 'summary:'
    entry.  Keys containing 'vb_' are shown only when `with_detail` is
    the string 'detail'.  Summary entries are left out when no matching
    stats were returned.
    """
    h = mc.stats('hash')
    with_detail = with_detail == 'detail'

    mins = []
    maxes = []
    counts = []
    # Iterate over a snapshot: summary keys are inserted into h below, and
    # a dict must not change size while it is being iterated.
    for k, v in list(h.items()):
        if 'max_dep' in k:
            maxes.append(int(v))
        if 'min_dep' in k:
            mins.append(int(v))
        if 'counted' in k:
            counts.append(int(v))
        if ':histo' in k:
            # Assumes exactly one ':' in histogram keys.
            vb, kbucket = k.split(':')
            skey = 'summary:' + kbucket
            h[skey] = int(v) + h.get(skey, 0)

    # avg(), min() and max() all fail on an empty list, so each group of
    # summary entries is only added when there is data for it.
    if mins:
        h['avg_min'] = avg(mins)
        h['largest_min'] = max(mins)
    if maxes:
        h['avg_max'] = avg(maxes)
        h['largest_max'] = max(maxes)
    if counts:
        h['avg_count'] = avg(counts)
        h['min_count'] = min(counts)
        h['max_count'] = max(counts)
        h['total_counts'] = sum(counts)

    if with_detail:
        toDisplay = h
    else:
        toDisplay = dict((k, v) for k, v in h.items() if 'vb_' not in k)

    stats_formatter(toDisplay, cmp=magic_cmp)

@cmd
def stats_dispatcher(mc, with_logs='no'):
    """Print the 'dispatcher' stats grouped by dispatcher name.

    Keys are split on ':'; the first part is the dispatcher name.  Keys
    whose second part is 'log' or 'slow' are job-log entries of the form
    name:log|slow:offset:field and are shown only when `with_logs` is the
    string 'logs'.  Values of keys ending in 'runtime' are rendered with
    time_label.
    """
    show_logs = with_logs == 'logs'
    raw = mc.stats('dispatcher')

    summary = {}
    recent = {}
    slow = {}
    log_tables = {'log': recent, 'slow': slow}

    for name, value in raw.items():
        parts = tuple(name.split(':'))
        if parts[-1] == 'runtime':
            value = time_label(int(value))

        dispatcher = parts[0]

        # Every dispatcher gets an entry in all three tables, even if empty.
        recent.setdefault(dispatcher, {})
        slow.setdefault(dispatcher, {})
        summary.setdefault(dispatcher, {})

        kind = parts[1]
        if kind in log_tables:
            offset = int(parts[2])
            field = parts[3]
            entries = log_tables[kind][dispatcher]
            entries.setdefault(offset, {})[field] = value
        else:
            summary[dispatcher][kind] = value

    for dispatcher in sorted(summary):
        print(" %s" % dispatcher)
        stats_formatter(summary[dispatcher], "     ")
        if not show_logs:
            continue
        for title, table in (('Slow jobs', slow), ('Recent jobs', recent)):
            entries = table[dispatcher]
            if not entries:
                continue
            print("     %s:" % title)
            for offset, fields in sorted(entries.items()):
                stats_formatter(fields, "        ")
                print("        ---------")

@cmd
def reset(mc):
    """Issue the 'reset' stats command; the reply is discarded."""
    command = 'reset'
    mc.stats(command)

def main():
    """Register every sub-command and the -a flag with a CliTool, then
    hand control to it."""
    # (command name, handler, usage text), in registration order.
    commands = (
        ('vkey', stats_vkey, 'vkey keyname vbid [username password]'),
        ('all', stats_all, 'all [username password]'),
        ('hash', stats_hash, 'hash [detail] [username password]'),
        ('checkpoint', stats_checkpoint, 'checkpoint [username password]'),
        ('tap', stats_tap, 'tap [username password]'),
        ('tapagg', stats_tapagg, 'tapagg [username password]'),
        ('timings', stats_timings, 'timings [username password]'),
        ('raw', stats_raw, 'raw argument [username password]'),
        ('reset', reset, 'reset [username password]'),
        ('dispatcher', stats_dispatcher, 'dispatcher [logs] [username password]'),
        ('slabs', stats_slabs, 'slabs [username password]'),
        ('items', stats_items, 'items [username password]'),
    )

    tool = clitool.CliTool()
    for name, handler, usage in commands:
        tool.addCommand(name, handler, usage)
    tool.addFlag('-a', 'allBuckets', 'iterate over all buckets (requires admin u/p)')

    tool.execute()

# Run the command-line tool only when executed as a script, not on import.
if __name__ == '__main__':
    main()
